"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


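This script builds a temporal network (ContTempNetwork) from the APS dataset.

It first saves the network with overlapping events merged as
`../paper_data/aps/aps_net.pickle`, then saves the monthly-resolution network
restricted to its largest connected component as
`../paper_data/aps/aps_monthly_lcc_net.pickle`.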

"""

import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import pandas as pd
from TemporalNetwork import ContTempNetwork
import time
from multiprocessing import Pool
import numpy as np


# Edge list of the APS dataset. The `date` column is parsed to datetimes
# because the rest of the script does date arithmetic on it (`.dt.days` on
# date differences, subtraction of a `timedelta`); `read_csv` would otherwise
# load it as plain strings and that arithmetic would fail.
df_edges = pd.read_csv('../data/aps/all_journals_disamb_edges.csv.gz',
                       index_col=0, parse_dates=['date'])

# number of worker processes used to merge overlapping events
nproc = 7
# print a progress line for every node processed by a worker
verbose = True

df_authors = pd.read_csv('../data/aps/APS_authors.csv', index_col=0)

# length of collaboration (days)
colab_length = 365
# NOTE(review): unconditional stop. Run as a plain script, nothing below this
# line executes; presumably a guard so the cells below are only run one by one
# in an interactive session -- confirm before removing.
raise Exception
#%%



# Start of every event, in days counted from the date of the first row of the
# edge table.
rel_start_days = (df_edges['date'] - df_edges['date'].iloc[0]).dt.days

# One event per edge, lasting `colab_length` days from its start.
net = ContTempNetwork(
    source_nodes=df_edges['n1'].tolist(),
    target_nodes=df_edges['n2'].tolist(),
    starting_times=rel_start_days,
    ending_times=rel_start_days + colab_length,
)

# keep the DOI and the publication date next to each event
net.events_table['doi'] = df_edges['doi'].tolist()
net.events_table['pub_date'] = df_edges['date'].tolist()

#%% merge overlapping

# Static adjacency matrix, symmetrised so that row n lists every node that
# shares an event with n, whichever end of the event n is on.
A = net.compute_static_adjacency_matrix()
A = A + A.T

def worker(n1):
    """Find the overlapping events between node `n1` and each of its neighbours.

    Reads the module-level `net` (its `events_table`), `A` (the symmetrised
    static adjacency matrix) and `verbose`. Nothing is modified here: the
    returned lists describe the changes and are applied by the caller.

    For every neighbour `n2` of `n1`, the events between the two nodes (in
    either direction) are sorted by starting time and swept once. An event
    that starts before the end of the current merged event is marked for
    deletion and the current event is extended to cover it; otherwise it
    becomes the new current event.

    Parameters
    ----------
    n1 : int
        Node index, used as a row index of `A`.

    Returns
    -------
    evs_to_delete : list
        Index labels (in `net.events_table`) of events absorbed by an earlier
        event.
    evs_to_prolong : list of tuple
        `(index_label, new_ending_time)` pairs in sweep order. The same label
        can appear several times; the last pair holds the final ending time.
    """
    evs_to_delete = []   # Id of events to delete
    evs_to_prolong = []  # (Id, new_end_time) for events to prolong

    # Loop invariants: the node columns do not depend on the neighbour.
    sources = net.events_table.source_nodes.values
    targets = net.events_table.target_nodes.values

    # loop over the neighbours of n1
    for n2 in (A[n1, :] > 0).nonzero()[1]:

        mask_12 = np.logical_and(sources == n1, targets == n2)
        mask_21 = np.logical_and(sources == n2, targets == n1)

        # sort by starting times
        evs = net.events_table.loc[np.logical_or(mask_12, mask_21)].\
                        sort_values(by=['starting_times', 'ending_times'])

        evs_list = list(evs.itertuples())
        if not evs_list:
            # nothing to merge for this pair
            continue

        # current (possibly already extended) event to compare against
        ev1 = evs_list[0]
        for ev2 in evs_list[1:]:
            # if ev2 overlaps with ev1, merge them, otherwise ev2 becomes
            # ev1
            if ev2.starting_times < ev1.ending_times:
                # never shorten ev1 if ev2 lies entirely inside it
                new_end = max(ev1.ending_times, ev2.ending_times)
                evs_to_delete.append(ev2.Index)
                evs_to_prolong.append((ev1.Index, new_end))
                # `_replace` returns a new tuple; rebind it so that the next
                # comparison sees the extended ending time
                ev1 = ev1._replace(ending_times=new_end)
            else:
                ev1 = ev2

    if verbose:
        print(f"PID:{os.getpid()}:", f'done {n1}')

    return evs_to_delete, evs_to_prolong

#%%
if __name__ == '__main__':
    t0 = time.time()

    # One task per node; every task returns the pair
    # (events to delete, events to prolong) for that node.
    with Pool(nproc) as pool:
        res = pool.map(worker, range(net.num_nodes))

    #%%
    # Boolean mask over the rows of the events table; True = keep the row.
    events_to_keep = np.ones(len(net.events_table), dtype=bool)

    for absorbed_ids, prolongations in res:
        if not absorbed_ids:
            continue

        # NOTE(review): index labels are used as row positions here --
        # assumes the events table has a default 0..n-1 index; confirm.
        events_to_keep[absorbed_ids] = False

        # applied in order, so the last ending time given for an event wins
        for ev_id, new_end in prolongations:
            net.events_table.loc[ev_id, 'ending_times'] = new_end

    num_merged = np.logical_not(events_to_keep).sum()
    print('PID ', os.getpid(), ' : ', 'merged ', num_merged, ' events')

    # drop the absorbed events and renumber the remaining ones
    net.events_table = net.events_table.loc[events_to_keep]
    net.events_table.reset_index(inplace=True, drop=True)

    # refresh the attributes derived from the events table
    net.num_nodes = net.node_array.shape[0]
    net.num_events = net.events_table.shape[0]
    net.start_time = net.events_table.starting_times.min()
    net.end_time = net.events_table.ending_times.max()
    net._compute_time_grid()

    filename = '../paper_data/aps/aps_net.pickle'
    net.save(
        filename=filename,
        attributes_list=[
            'node_to_label_dict',
            'events_table',
            'times',
            'time_grid',
            'num_nodes',
        ],
    )

    print('finished', time.time() - t0)
    


#%% make monthly resolution network

# Reload the network with merged events.
# NOTE(review): 'time_slices_bounds' and 'time_slices_bounds_datetimes' are
# requested here but are not in the attribute list used when this file is
# saved earlier in this script -- confirm where they come from.
filename = '../paper_data/aps/aps_net.pickle'
net = ContTempNetwork.load(
    filename=filename,
    attributes_list=[
        'node_to_label_dict',
        'events_table',
        'times',
        'time_grid',
        'num_nodes',
        'time_slices_bounds',
        'time_slices_bounds_datetimes',
    ],
)

# Work on a copy of the events so the loaded network is left untouched;
# node ids and their labels are reused as they are.
monthly_net = ContTempNetwork(
    events_table=net.events_table.copy(),
    relabel_nodes=False,
    node_to_label_dict=net.node_to_label_dict,
)

from datetime import timedelta, datetime

# Calendar date used as day 0 when the event times (in days) are converted
# back to dates below: one collaboration length before the earliest
# publication date.
# NOTE(review): the event times were built as offsets from
# `df_edges.date.iloc[0]`, not from this date -- confirm the origin is right.
date_start = net.events_table['pub_date'].min() - timedelta(days=colab_length)

def start_of_month(d):
    """Return midnight on the first day of the month that contains `d`.

    Only the `year` and `month` attributes of `d` are read; the result is
    always a plain `datetime`.
    """
    return datetime(year=d.year, month=d.month, day=1)

# First day of the month containing the reference date; every month-resolution
# time below is a number of days counted from this day.
first_month = start_of_month(date_start)

# start/end times in days with month resolution: each time is moved back to
# the first day of its calendar month
start_months = [
    (start_of_month(date_start + timedelta(t)) - first_month).days
    for t in monthly_net.events_table['starting_times']
]
end_months = [
    (start_of_month(date_start + timedelta(t)) - first_month).days
    for t in monthly_net.events_table['ending_times']
]

monthly_net.events_table['start_months'] = start_months
monthly_net.events_table['end_months'] = end_months

# the month-resolution times replace the day-resolution ones
monthly_net.events_table['starting_times'] = monthly_net.events_table['start_months']
monthly_net.events_table['ending_times'] = monthly_net.events_table['end_months']

monthly_net._compute_time_grid()


# remove nodes not in LCC
A = monthly_net.compute_static_adjacency_matrix()

from scipy.sparse.csgraph import connected_components


# `label[i]` is the id of the connected component that contains node i
ncomp, label = connected_components(A, directed=False, connection='weak')

# Id of the largest component. `argmax` always gives a single id (the first
# one if several components tie for the largest size), so the comparison
# below is always array-vs-scalar.
component_sizes = np.bincount(label)
lcc_lab = component_sizes.argmax()

nodes_not_in_lcc, = (label != lcc_lab).nonzero()

# remove nodes not in lcc and update label dict

old_dict = monthly_net.node_to_label_dict.copy()

# drop, in place, every event with at least one end outside the LCC
outside_lcc = (monthly_net.events_table.source_nodes.isin(nodes_not_in_lcc)
               | monthly_net.events_table.target_nodes.isin(nodes_not_in_lcc))
monthly_net.events_table.drop(
    index=monthly_net.events_table.loc[outside_lcc].index, inplace=True)

# translate the remaining node ids to their labels
source_labels = monthly_net.events_table.source_nodes.map(monthly_net.node_to_label_dict).tolist()
target_labels = monthly_net.events_table.target_nodes.map(monthly_net.node_to_label_dict).tolist()

# Rebuild a network from the labels of the remaining events, carrying the DOI
# and the publication date along as extra attributes.
net_final = ContTempNetwork(source_nodes=source_labels,
                            target_nodes=target_labels,
                            starting_times=monthly_net.events_table.starting_times.tolist(),
                            ending_times=monthly_net.events_table.ending_times.tolist(),
                            extra_attrs={'doi': monthly_net.events_table['doi'].tolist(),
                                         'pub_date': monthly_net.events_table['pub_date'].tolist()})

# Sanity check: the rebuilt network must consist of one connected component.
Af = net_final.compute_static_adjacency_matrix()
ncompf, labelf = connected_components(Af, directed=False, connection='weak')
assert ncompf == 1

net_final._compute_time_grid()

# carry the time-slice bounds over from the loaded network
net_final.time_slices_bounds = net.time_slices_bounds
net_final.time_slices_bounds_datetimes = net.time_slices_bounds_datetimes


filename = '../paper_data/aps/aps_monthly_lcc_net.pickle'
net_final.save(
    filename=filename,
    attributes_list=[
        'node_to_label_dict',
        'events_table',
        'times',
        'time_grid',
        'num_nodes',
        'time_slices_bounds',
        'time_slices_bounds_datetimes',
    ],
)
